iT邦幫忙

2023 iThome 鐵人賽

DAY 7
0

儲存模型

儲存序列化模型內部狀態和權重,命名為model.pth,程式碼如下:

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

儲存成功會出現以下文字:
https://ithelp.ithome.com.tw/upload/images/20230912/20153503HOEKwBiF1V.png

載入模型

載入模型通常會在第一次啟動程式的時候執行,以下程式碼為載入模型的架構和權重:

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth"))

成功載入會出現以下文字:
https://ithelp.ithome.com.tw/upload/images/20230912/20153503CTOvY6I0As.png

模型推論

模型輸出的結果是class0~9的數值,因此要建立原本類別的順序的列表,假如模型輸出結果是class0的數值最大,代表預測的結果是T-shirt/top

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
# 取得第0筆資料的圖像和答案
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

看起來AI預測正確,若有興趣的人可以嘗試將以上程式碼的 x, y = test_data[這個][0], test_data[這個][1]第一維改成其他數字,看看AI在其他圖片有沒有預測成功。
https://ithelp.ithome.com.tw/upload/images/20230912/20153503qKwOwuJPg2.png

結語:
今天介紹模型如何儲存到硬碟中,如果重開機或是要在另外一台電腦啟動就可以使用預先儲存好的權重檔案載入模型,從Day4到今天將整個AI模型的流程跑過一次,接下來會進入細節的講解。


上一篇
Day 6 調教你的AI模型(Pytorch)
下一篇
Day8 深度學習常提到的張量是什麼?
系列文
30天把AI知識傳授給女友30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言